{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aee49696",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp py2pyi"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83a6b2d0",
   "metadata": {},
   "source": [
    "# Create delegated pyi"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ab07e20",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3433906e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import ast, sys, inspect, re, os, importlib.util, importlib.machinery\n",
    "\n",
    "from ast import parse, unparse\n",
    "from inspect import signature, getsource\n",
    "from fastcore.utils import *\n",
    "from fastcore.meta import delegates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d9f282d-65cb-44bf-adc8-62fd24ea550b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import copy\n",
    "from fastcore.test import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65bfdfa1",
   "metadata": {},
   "source": [
    "## Basics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "444ab985",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def imp_mod(module_path, package=None):\n",
    "    \"Import dynamically the module at `module_path`, optionally setting its `__package__` to `package`\"\n",
    "    module_path = str(module_path)\n",
    "    # The module name is the file's base name without its extension\n",
    "    module_name = os.path.splitext(os.path.basename(module_path))[0]\n",
    "    # The spec starts without a loader; one is attached only after the module object exists\n",
    "    spec = importlib.machinery.ModuleSpec(module_name, None, origin=module_path)\n",
    "    module = importlib.util.module_from_spec(spec)\n",
    "    spec.loader = importlib.machinery.SourceFileLoader(module_name, module_path)\n",
    "    # presumably lets relative imports inside the file resolve against `package` -- TODO confirm\n",
    "    if package is not None: module.__package__ = package\n",
    "    module.__file__ = os.path.abspath(module_path)\n",
    "    # Runs the file's top-level code; nothing in this function adds the module to `sys.modules`\n",
    "    spec.loader.exec_module(module)\n",
    "    return module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92ebf6b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "fn = Path('test_py2pyi.py')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1aa9f23f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mod = imp_mod(fn)\n",
    "a = mod.A()\n",
    "a.h()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa277c0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _get_tree(mod):\n",
    "    return parse(getsource(mod))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9751182f",
   "metadata": {},
   "outputs": [],
   "source": [
    "tree = _get_tree(mod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72cbc154",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def __repr__(self:ast.AST):\n",
    "    \"Show an AST node as the source code it unparses to\"\n",
    "    return unparse(self)\n",
    "\n",
    "@patch\n",
    "def _repr_markdown_(self:ast.AST):\n",
    "    \"Markdown form of the node: its `repr` inside a fenced `python` code block\"\n",
    "    return f\"\"\"```python\n",
    "{self!r}\n",
    "```\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a61ebb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for o in enumerate(tree.body): print(o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccb83c32",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "def f(a: int, b: str='a') -> str:\n",
       "    \"\"\"I am f\"\"\"\n",
       "    return 1\n",
       "```"
      ],
      "text/plain": [
       "def f(a: int, b: str='a') -> str:\n",
       "    \"\"\"I am f\"\"\"\n",
       "    return 1"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node = tree.body[4]\n",
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f971f40",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "# AST node classes that count as function definitions (sync and async)\n",
    "functypes = (ast.FunctionDef,ast.AsyncFunctionDef)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c96ad101",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "isinstance(node, functypes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a137fab9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _deco_id(d:Union[ast.Name,ast.Attribute])->bool:\n",
    "    \"Get the id for AST node `d`\"\n",
    "    return d.id if isinstance(d, ast.Name) else d.func.id\n",
    "\n",
    "def has_deco(node:Union[ast.FunctionDef,ast.AsyncFunctionDef], name:str)->bool:\n",
    "    \"Check if a function node `node` has a decorator named `name`\"\n",
    "    return any(_deco_id(d)==name for d in getattr(node, 'decorator_list', []))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7dc103d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nm = 'delegates'\n",
    "has_deco(node, nm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e887842",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "@delegates(f)\n",
       "def g(c, d: X, **kwargs) -> str:\n",
       "    \"\"\"I am g\"\"\"\n",
       "    return 2\n",
       "```"
      ],
      "text/plain": [
       "@delegates(f)\n",
       "def g(c, d: X, **kwargs) -> str:\n",
       "    \"\"\"I am g\"\"\"\n",
       "    return 2"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node = tree.body[5]\n",
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88d463f2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "has_deco(node, nm)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3f32ec7",
   "metadata": {},
   "source": [
    "## Function processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee690f4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder handlers that only print which processor was chosen; the real ones are defined in later cells\n",
    "def _proc_body   (node, mod): print('_proc_body', type(node))\n",
    "def _proc_func   (node, mod): print('_proc_func', type(node))\n",
    "def _proc_class  (node, mod): print('_proc_class', type(node))\n",
    "def _proc_patched(node, mod): print('_proc_patched', type(node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fd0818d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _get_proc(node):\n",
    "    \"Choose the processor for `node`; `None` means the node is left untouched\"\n",
    "    if isinstance(node, ast.ClassDef): handler = _proc_class\n",
    "    elif not isinstance(node, functypes): handler = None\n",
    "    # Functions without `delegates` only need their body blanked out\n",
    "    elif not has_deco(node, 'delegates'): handler = _proc_body\n",
    "    # Delegated functions need their signature expanded; patched ones resolve it through their class\n",
    "    else: handler = _proc_patched if has_deco(node, 'patch') else _proc_func\n",
    "    return handler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc1a0b3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _proc_tree(tree, mod):\n",
    "    \"Apply the matching processor to each direct child in `tree.body`\"\n",
    "    for child in tree.body:\n",
    "        handler = _get_proc(child)\n",
    "        if not handler: continue\n",
    "        handler(child, mod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8721e51d-858e-4f4c-9de7-db7ce8c48ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _clean_patched_node(node):\n",
    "    \"Clean the patched node in-place.\"\n",
    "    # When moving a patched node to its parent, we no longer need the patch decorator and parent annotation.\n",
    "    node.decorator_list = [deco for deco in node.decorator_list if getattr(deco, \"id\", None) != \"patch\"]\n",
    "    node.args.args[0].annotation = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ba19904-091c-4872-b0bc-8a79c7268ec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "parent_node = copy.deepcopy(tree.body[7])\n",
    "patched_node = copy.deepcopy(tree.body[10])\n",
    "test_is(has_deco(patched_node, \"patch\"), True)\n",
    "test_eq(str(patched_node.args.args[0].annotation), parent_node.name)\n",
    "\n",
    "_clean_patched_node(patched_node)\n",
    "test_is(has_deco(patched_node, \"patch\"), False)\n",
    "test_eq(patched_node.args.args[0].annotation, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71c676fb-dea7-466b-965c-36ded7fabf87",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _is_empty_class(node):\n",
    "    if not isinstance(node, ast.ClassDef): return False\n",
    "    if len(node.body) != 1: return False\n",
    "    child = node.body[0]\n",
    "    if isinstance(child, ast.Pass): return True\n",
    "    if isinstance(child, ast.Expr) and isinstance(child.value, (ast.Ellipsis, ast.Str)): return True\n",
    "    return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd27cf2c-caa0-4bbd-92d8-8433fe257a29",
   "metadata": {},
   "outputs": [],
   "source": [
    "empty_cls1, empty_cls2, empty_cls3 = ast.parse('''\n",
    "class A: \n",
    "    \"\"\"An empty class.\"\"\"\n",
    "class B: \n",
    "    pass\n",
    "class C: \n",
    "    ...\n",
    "''').body\n",
    "\n",
    "test_is(_is_empty_class(empty_cls1), True)\n",
    "test_is(_is_empty_class(empty_cls2), True)\n",
    "test_is(_is_empty_class(empty_cls3), True)\n",
    "\n",
    "non_empty_cls, empty_func = ast.parse('''\n",
    "class A: \n",
    "    a = 1\n",
    "def f():\n",
    "    ...\n",
    "''').body\n",
    "test_is(_is_empty_class(non_empty_cls), False)\n",
    "test_is(_is_empty_class(empty_func), False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f34cb57-c780-44a3-b7e8-99d763e51211",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _add_patched_node_to_parent(node, parent):\n",
    "    \"Add a patched node to its parent.\"\n",
    "    # A member of the class with the same name is overridden by the patched node.\n",
    "    for pos in range(len(parent.body)):\n",
    "        member = parent.body[pos]\n",
    "        if not hasattr(member, \"name\"): continue\n",
    "        if member.name != node.name: continue\n",
    "        parent.body[pos] = node\n",
    "        return\n",
    "    # No member has that name, so the node is new: it goes at the end of the class body,\n",
    "    # unless the class is empty, in which case it replaces the placeholder statement.\n",
    "    if not _is_empty_class(parent):\n",
    "        parent.body.append(node)\n",
    "        return\n",
    "    parent.body = [node]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7625b701-28dd-4a9f-84f8-e56b2bd68f7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# we could have reused `parent_node` and `patched_node` from the previous cells.\n",
    "# copying them here allows us to run this cell multiple times which makes it a little easier to write tests.\n",
    "\n",
    "parent_node = copy.deepcopy(tree.body[7])\n",
    "patched_node = copy.deepcopy(tree.body[11])\n",
    "test_eq(len(parent_node.body),1)\n",
    "_add_patched_node_to_parent(patched_node, parent_node)\n",
    "test_eq(len(parent_node.body),2)\n",
    "test_eq(parent_node.body[-1], patched_node)\n",
    "\n",
    "# patched node replaces an existing class method (A.h)\n",
    "patched_h_node = ast.parse(\"\"\"\n",
    "@patch\n",
    "def h(self: A, *args, **kwargs):\n",
    "    ...\n",
    "\"\"\", mode='single').body[0]\n",
    "\n",
    "_add_patched_node_to_parent(patched_h_node, parent_node)\n",
    "test_eq(len(parent_node.body), 2)\n",
    "test_eq(parent_node.body[0], patched_h_node)\n",
    "\n",
    "# patched node is added to an empty class\n",
    "empty_cls, patched_node = ast.parse('''\n",
    "class Z: \n",
    "    \"\"\"An empty class.\"\"\"\n",
    "\n",
    "@patch\n",
    "def a(self: Z, *args, **kwargs):\n",
    "    ...\n",
    "''').body\n",
    "\n",
    "test_eq(len(empty_cls.body), 1)\n",
    "test_ne(empty_cls.body[0], patched_node)\n",
    "_add_patched_node_to_parent(patched_node, empty_cls)\n",
    "test_eq(len(empty_cls.body), 1)\n",
    "test_eq(empty_cls.body[0], patched_node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "689de555-af4a-4abd-ac6c-2296c2ca3a24",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _proc_patches(tree, mod):\n",
    "    \"Move all patched methods (sync or async) to their parents.\"\n",
    "    # Classes seen so far, reachable by the names a patched function may use for them:\n",
    "    #  - local scope (e.g. self: A)\n",
    "    #  - modular scope (e.g. self: example.A) [used by @delegates]\n",
    "    classes = {}\n",
    "    remaining = []  # top-level nodes that stay where they are\n",
    "    for node in tree.body:\n",
    "        targets = []\n",
    "        if isinstance(node, ast.ClassDef):\n",
    "            classes[node.name] = node\n",
    "            classes[f\"{mod.__name__}.{node.name}\"] = node\n",
    "        # `functypes` keeps this in line with `_get_proc`, which also treats `async def` as a function.\n",
    "        # A function without positional parameters has no annotation naming a class, so it is skipped.\n",
    "        elif isinstance(node, functypes) and has_deco(node, \"patch\") and node.args.args:\n",
    "            annotation = node.args.args[0].annotation\n",
    "            # a patched node can have 1 or more parents\n",
    "            parents = annotation.elts if hasattr(annotation, \"elts\") else [annotation]\n",
    "            # `str` of an AST node is its source text, thanks to the `__repr__` patched onto `ast.AST`\n",
    "            targets = [classes[name] for name in map(str, parents) if name in classes]\n",
    "        # we can only move the patched node if at least one parent lives in the current module\n",
    "        if not targets:\n",
    "            remaining.append(node)\n",
    "            continue\n",
    "        _clean_patched_node(node)\n",
    "        for target in targets: _add_patched_node_to_parent(node, target)\n",
    "    # Update the list in place so other references to `tree.body` see the result\n",
    "    tree.body[:] = remaining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01816c3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _proc_mod(mod):\n",
    "    \"Build the processed AST for `mod`: process its definitions, then move patched methods into their classes\"\n",
    "    mod_tree = _get_tree(mod)\n",
    "    for step in (_proc_tree, _proc_patches): step(mod_tree, mod)\n",
    "    return mod_tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef401c95-c1f5-48a9-90e9-bee534662fc9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_proc_class <class 'ast.ClassDef'>\n",
      "_proc_body <class 'ast.FunctionDef'>\n",
      "_proc_func <class 'ast.FunctionDef'>\n",
      "_proc_body <class 'ast.FunctionDef'>\n",
      "_proc_class <class 'ast.ClassDef'>\n",
      "_proc_class <class 'ast.ClassDef'>\n",
      "_proc_patched <class 'ast.FunctionDef'>\n",
      "_proc_patched <class 'ast.FunctionDef'>\n",
      "_proc_body <class 'ast.FunctionDef'>\n"
     ]
    }
   ],
   "source": [
    "raw_tree = _get_tree(mod)\n",
    "processed_tree = _proc_mod(mod)\n",
    "n_raw_tree_nodes = len(raw_tree.body)\n",
    "# mod contains 3 patched methods, so processed_tree should have 3 fewer top-level nodes\n",
    "test_eq(len(processed_tree.body), n_raw_tree_nodes-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2eee5c5a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_proc_class <class 'ast.ClassDef'>\n",
      "_proc_body <class 'ast.FunctionDef'>\n",
      "_proc_func <class 'ast.FunctionDef'>\n",
      "_proc_body <class 'ast.FunctionDef'>\n",
      "_proc_class <class 'ast.ClassDef'>\n",
      "_proc_class <class 'ast.ClassDef'>\n",
      "_proc_patched <class 'ast.FunctionDef'>\n",
      "_proc_patched <class 'ast.FunctionDef'>\n",
      "_proc_body <class 'ast.FunctionDef'>\n"
     ]
    }
   ],
   "source": [
    "_proc_mod(mod);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e38a0d6d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'g'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85c1187d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<function test_py2pyi.g(c, d: test_py2pyi.X, *, b: str = 'a') -> str>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sym = getattr(mod, node.name)\n",
    "sym"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdbbe50d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(c, d: test_py2pyi.X, *, b: str = 'a') -> str\n"
     ]
    }
   ],
   "source": [
    "sig = signature(sym)\n",
    "print(sig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "936f551a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def sig2str(sig):\n",
    "    s = str(sig)\n",
    "    s = re.sub(r\"<class '(.*?)'>\", r'\\1', s)\n",
    "    s = re.sub(r\"dynamic_module\\.\", \"\", s)\n",
    "    return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f63b9ecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def ast_args(func):\n",
    "    \"Return the `ast.arguments` node describing the signature of `func`\"\n",
    "    # The signature text is wrapped in a throwaway `def` so that it can be parsed\n",
    "    stub_src = f\"def _{sig2str(signature(func))}: ...\"\n",
    "    stub = ast.parse(stub_src).body[0]\n",
    "    return stub.args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f32e1d00",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "c, d: test_py2pyi.X, *, b: str='a'\n",
       "```"
      ],
      "text/plain": [
       "c, d: test_py2pyi.X, *, b: str='a'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "newargs = ast_args(sym)\n",
    "newargs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b343db0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "c, d: X, **kwargs\n",
       "```"
      ],
      "text/plain": [
       "c, d: X, **kwargs"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node.args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d48fb797",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "@delegates(f)\n",
       "def g(c, d: test_py2pyi.X, *, b: str='a') -> str:\n",
       "    \"\"\"I am g\"\"\"\n",
       "    return 2\n",
       "```"
      ],
      "text/plain": [
       "@delegates(f)\n",
       "def g(c, d: test_py2pyi.X, *, b: str='a') -> str:\n",
       "    \"\"\"I am g\"\"\"\n",
       "    return 2"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node.args = newargs\n",
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7e1ee39",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _body_ellip(n: ast.AST):\n",
    "    stidx = 1 if isinstance(n.body[0], ast.Expr) and isinstance(n.body[0].value, ast.Str) else 0\n",
    "    n.body[stidx:] = [ast.Expr(ast.Constant(...))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e47bfb08",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "@delegates(f)\n",
       "def g(c, d: test_py2pyi.X, *, b: str='a') -> str:\n",
       "    \"\"\"I am g\"\"\"\n",
       "    ...\n",
       "```"
      ],
      "text/plain": [
       "@delegates(f)\n",
       "def g(c, d: test_py2pyi.X, *, b: str='a') -> str:\n",
       "    \"\"\"I am g\"\"\"\n",
       "    ..."
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_body_ellip(node)\n",
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "741c3d0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _update_func(node, sym):\n",
    "    \"\"\"Replace the parameter list of function node `node` with the signature of the live object `sym`.\n",
    "    Replace the body of the function with just `...`, and remove any decorators named 'delegates'\"\"\"\n",
    "    # sym_args contains the complete set of node args including any delegates kwargs.\n",
    "    # when adding the delegates kwargs any annotation with a user-defined type is given a fully qualified name (e.g. mod.submod.func).\n",
    "    # unfortunately the fully qualified names don't match the import statements in the node tree.\n",
    "    # this can break downstream applications such as \"jump to definition\" in IDEs.\n",
    "    # to resolve this issue we replace the annotations of user-defined types with the original annotations.\n",
    "    def _named_params(a):\n",
    "        \"Positional-only, regular and keyword-only parameter nodes of `a`\"\n",
    "        return a.posonlyargs + a.args + a.kwonlyargs\n",
    "    sym_args = ast_args(sym)\n",
    "    node_args = {arg.arg: arg.annotation for arg in _named_params(node.args)}\n",
    "    for arg in _named_params(sym_args):\n",
    "        # parameters that only exist in `sym` (the delegated ones) keep the annotation from its signature\n",
    "        arg.annotation = node_args.get(arg.arg, arg.annotation)\n",
    "    node.args = sym_args\n",
    "    _body_ellip(node)\n",
    "    node.decorator_list = [d for d in node.decorator_list if _deco_id(d) != 'delegates']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8d675dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "@delegates(f)\n",
       "def g(c, d: X, **kwargs) -> str:\n",
       "    \"\"\"I am g\"\"\"\n",
       "    return 2\n",
       "```"
      ],
      "text/plain": [
       "@delegates(f)\n",
       "def g(c, d: X, **kwargs) -> str:\n",
       "    \"\"\"I am g\"\"\"\n",
       "    return 2"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree = _get_tree(mod)\n",
    "node = tree.body[5]\n",
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f3c3f00",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "def g(c, d: X, *, b: str='a') -> str:\n",
       "    \"\"\"I am g\"\"\"\n",
       "    ...\n",
       "```"
      ],
      "text/plain": [
       "def g(c, d: X, *, b: str='a') -> str:\n",
       "    \"\"\"I am g\"\"\"\n",
       "    ..."
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_update_func(node, sym)\n",
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77c148b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _proc_body(node, mod):\n",
    "    \"Blank out the body of `node`; `mod` is unused\"\n",
    "    _body_ellip(node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20a4bd05",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _proc_func(node, mod):\n",
    "    \"Rewrite function `node` using the signature of the same-named attribute of `mod`\"\n",
    "    _update_func(node, getattr(mod, node.name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47473351",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_proc_class <class 'ast.ClassDef'>\n",
      "_proc_class <class 'ast.ClassDef'>\n",
      "_proc_class <class 'ast.ClassDef'>\n",
      "_proc_patched <class 'ast.FunctionDef'>\n",
      "_proc_patched <class 'ast.FunctionDef'>\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "def g(c, d: X, *, b: str='a') -> str:\n",
       "    \"\"\"I am g\"\"\"\n",
       "    ...\n",
       "```"
      ],
      "text/plain": [
       "def g(c, d: X, *, b: str='a') -> str:\n",
       "    \"\"\"I am g\"\"\"\n",
       "    ..."
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree = _proc_mod(mod)\n",
    "tree.body[5]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8ed1c88",
   "metadata": {},
   "source": [
    "## Patch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e8f3835",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "@patch\n",
       "@delegates(j)\n",
       "def k(self: (A, B), b: bool=False, **kwargs):\n",
       "    return 1\n",
       "```"
      ],
      "text/plain": [
       "@patch\n",
       "@delegates(j)\n",
       "def k(self: (A, B), b: bool=False, **kwargs):\n",
       "    return 1"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree = _get_tree(mod)\n",
    "node = tree.body[9]\n",
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9ff7554",
   "metadata": {},
   "outputs": [],
   "source": [
    "ann = node.args.args[0].annotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc01f0c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "if hasattr(ann, 'elts'): ann = ann.elts[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5832e546",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'A'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nm = ann.id\n",
    "nm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee82124f",
   "metadata": {},
   "outputs": [],
   "source": [
    "cls = getattr(mod, nm)\n",
    "sym = getattr(cls, node.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3973e41c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"(self: (test_py2pyi.A, test_py2pyi.B), b: bool = False, *, d: str = 'a')\""
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sig2str(signature(sym))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83372276",
   "metadata": {},
   "outputs": [],
   "source": [
    "_update_func(node, sym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b7ec59e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "@patch\n",
       "def k(self: (A, B), b: bool=False, *, d: str='a'):\n",
       "    ...\n",
       "```"
      ],
      "text/plain": [
       "@patch\n",
       "def k(self: (A, B), b: bool=False, *, d: str='a'):\n",
       "    ..."
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0743766",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _proc_patched(node, mod):\n",
    "    \"Rewrite patched function `node` using the signature of the same-named attribute of its class in `mod`\"\n",
    "    target = node.args.args[0].annotation\n",
    "    # With a tuple annotation (several patched classes) the first class supplies the signature\n",
    "    target = getattr(target, 'elts', [target])[0]\n",
    "    # NOTE(review): a dotted annotation (`pkg.A`) is an `ast.Attribute` and has no `.id` -- confirm this never occurs here\n",
    "    owner = getattr(mod, target.id)\n",
    "    _update_func(node, getattr(owner, node.name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fa82069",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "@patch\n",
       "@delegates(j)\n",
       "def k(self: (A, B), b: bool=False, **kwargs):\n",
       "    return 1\n",
       "```"
      ],
      "text/plain": [
       "@patch\n",
       "@delegates(j)\n",
       "def k(self: (A, B), b: bool=False, **kwargs):\n",
       "    return 1"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree = _get_tree(mod)\n",
    "tree.body[9]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2beb88d2",
   "metadata": {},
   "source": [
    "## Class and file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cbba27f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "class A:\n",
       "\n",
       "    @delegates(j)\n",
       "    def h(self, b: bool=False, **kwargs):\n",
       "        a = 1\n",
       "        return a\n",
       "```"
      ],
      "text/plain": [
       "class A:\n",
       "\n",
       "    @delegates(j)\n",
       "    def h(self, b: bool=False, **kwargs):\n",
       "        a = 1\n",
       "        return a"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree = _get_tree(mod)\n",
    "node = tree.body[7]\n",
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b648a5e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[@delegates(j)\n",
       " def h(self, b: bool=False, **kwargs):\n",
       "     a = 1\n",
       "     return a]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node.body"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07ba304d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _proc_class(node, mod):\n",
    "    \"Process the members of class `node`, looking them up on the same-named attribute of `mod`\"\n",
    "    _proc_tree(node, getattr(mod, node.name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e896dcfa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "class A:\n",
       "\n",
       "    def h(self, b: bool=False, *, d: str='a'):\n",
       "        ...\n",
       "\n",
       "    def k(self, b: bool=False, *, d: str='a'):\n",
       "        ...\n",
       "\n",
       "    def m(self, b: bool=False, *, d: str='a'):\n",
       "        ...\n",
       "\n",
       "    def n(self, b: bool=False, **kwargs):\n",
       "        \"\"\"No delegates here mmm'k?\"\"\"\n",
       "        ...\n",
       "```"
      ],
      "text/plain": [
       "class A:\n",
       "\n",
       "    def h(self, b: bool=False, *, d: str='a'):\n",
       "        ...\n",
       "\n",
       "    def k(self, b: bool=False, *, d: str='a'):\n",
       "        ...\n",
       "\n",
       "    def m(self, b: bool=False, *, d: str='a'):\n",
       "        ...\n",
       "\n",
       "    def n(self, b: bool=False, **kwargs):\n",
       "        \"\"\"No delegates here mmm'k?\"\"\"\n",
       "        ..."
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree = _proc_mod(mod)\n",
    "tree.body[7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d149d1a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def create_pyi(fn, package=None):\n",
    "    \"Convert the `.py` file `fn` to a sibling `.pyi` file by removing function bodies and expanding `delegates` kwargs\"\n",
    "    src_path = Path(fn)\n",
    "    stub_tree = _proc_mod(imp_mod(src_path, package=package))\n",
    "    stub_src = unparse(stub_tree)\n",
    "    # The stub is written next to the source file, with only the suffix changed\n",
    "    src_path.with_suffix('.pyi').write_text(stub_src)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a4e6fb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "create_pyi(fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7067e29f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fn = Path('/Users/jhoward/git/fastcore/fastcore/docments.py')\n",
    "# create_pyi(fn, 'fastcore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "075cf465",
   "metadata": {},
   "source": [
    "## Script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "365271bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from fastcore.script import call_parse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a08be00d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@call_parse\n",
    "def py2pyi(fname:str,  # The file name to convert\n",
    "           package:str=None  # The parent package\n",
    "          ):\n",
    "    \"Convert `fname.py` to `fname.pyi` by removing function bodies and expanding `delegates` kwargs\"\n",
    "    create_pyi(fname, package=package)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe8e3665",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@call_parse\n",
    "def replace_wildcards(\n",
    "    # Path to the Python file to process\n",
    "    path: str):\n",
    "    \"Expand wildcard imports in the specified Python file.\"\n",
    "    target = Path(path)\n",
    "    # The file is rewritten in place with the expanded source\n",
    "    expanded = expand_wildcards(target.read_text())\n",
    "    target.write_text(expanded)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df973d4e",
   "metadata": {},
   "source": [
    "# Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad32b076",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc0e8a0a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
